[Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops #163724
Conversation
-- This commit includes the basic infra/utilities to add matchers for linalg.*conv*/*pool* ops, such that given a `linalg.generic` op it identifies which linalg.*conv*/*pool* op it is.
-- It adds a few representative linalg.*conv*/*pool* ops to demo the matchers' capability, and does so as part of the `linalg-specialize-generic-ops` pass.
-- The goal is directed towards addressing the aim of [[RFC] Op explosion in Linalg](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863) iteratively for `*conv*/*pooling*` ops.
-- This is part 1 of a series of PRs aimed at adding matchers for convolution ops.
-- For further details, refer to llvm#163374 (review).

Signed-off-by: Abhishek Varma <[email protected]>
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Abhishek Varma (Abhishek-Varma)

Changes: -- This commit includes the basic infra/utilities to add matchers for

Signed-off-by: Abhishek Varma <[email protected]>

Patch is 36.60 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/163724.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..771d753a8bddb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
+//===----------------------------------------------------------------------===//
+// Convolution matcher utility
+//===----------------------------------------------------------------------===//
+
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d68e358f..35861002e309e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
+/// Utility to replace `genericOp` with a convolution op of type `ConvOpTy`,
+/// using the given `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+ ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+ SmallVector<Value> inputs = genericOp.getDpsInputs();
+ ValueRange outputs = genericOp.getDpsInits();
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+ ? TypeRange(ValueRange(outputs))
+ : TypeRange{};
+ LinalgOp namedOp;
+ if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+ inputs, outputs);
+ } else {
+ Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+ Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+ genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+ }
+ return namedOp;
+}
+
+/// TODO(avarma): Convolution ops with a rank-2 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops with a rank-5 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+ rewriter, genericOp, dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+ rewriter, genericOp, dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops with a rank-7 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops with a rank-8 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
+// improve the search speed, the convolution ops have been segregated based on
+// the rank of the iterator types array.
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<utils::IteratorType> iteratorTypes =
+ genericOp.getIteratorTypesArray();
+ unsigned totalIterators = iteratorTypes.size();
+ switch (totalIterators) {
+ case 2:
+ return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+ case 4:
+ return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+ case 5:
+ return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+ case 6:
+ return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+ case 7:
+ return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+ case 8:
+ return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+ case 9:
+ return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
+ }
+ return failure();
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -316,6 +455,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}
+
+ // Convolution - e.g. *conv/pooling*
+ if (isaConvolutionOpInterface(genericOp)) {
+ return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
+ }
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722cf5426..c3c2819652129 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,6 +240,508 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
return iteratorType == utils::IteratorType::reduction;
}
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Utility to match block body for linalg.pool* ops.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+ Operation *defOp = yieldVal.getDefiningOp();
+ if (!(isa_and_present<OpTypes>(defOp) || ...))
+ return false;
+
+ BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+ BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+ if (!lhsArg || !rhsArg)
+ return false;
+ return true;
+}
+
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+ uint32_t mapIndex, uint32_t dimIndex) {
+ auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+ if (dimIndex < affineMap.getNumResults())
+ return affineMap.getResult(dimIndex);
+ return nullptr;
+}
+
+// Check if `expr` is either:
+// - a dimension expr alone (implying *1), or
+// - a multiplication of dimension expr by constant.
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+ int64_t &constantValue) {
+ if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
+ dim = dExpr;
+ constantValue = 1;
+ return true;
+ }
+
+ auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+ if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+ return false;
+
+ AffineExpr lhs = mulExpr.getLHS();
+ AffineExpr rhs = mulExpr.getRHS();
+
+ if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
+ if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
+ dim = dExpr;
+ constantValue = cst.getValue();
+ return true;
+ }
+ }
+ if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
+ if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
+ dim = dExpr;
+ constantValue = cst.getValue();
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps`, verify the following:
+/// indexingMaps[0].getResult(iDim) ==
+/// indexingMaps[1].getResult(fDim) * <CST_1> +
+/// indexingMaps[n-1].getResult(oDim) * <CST_2>
+/// where, CST_1 and CST_2 can be any constant.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+ unsigned fDim, unsigned oDim,
+ int64_t &dilation, int64_t &stride) {
+ unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+ AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+ auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+ if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+ return false;
+
+ AffineExpr dim0, dim1;
+ int64_t c0, c1;
+
+ if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
+ isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+ // Pattern matched with dims and constants extracted.
+ AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+ AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+ if (dim0 == fExpr && dim1 == oExpr) {
+ dilation = c0;
+ stride = c1;
+ return true;
+ } else if (dim1 == fExpr && dim0 == oExpr) {
+ dilation = c1;
+ stride = c0;
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps`, verify the following:
+/// indexingMaps[aIndex].getResult(aDim) ==
+/// indexingMaps[bIndex].getResult(bDim)
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+ unsigned aDim, unsigned bIndex,
+ unsigned bDim) {
+ return getAffineMapDim(indexingMaps, aIndex, aDim) ==
+ getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Given an array of AffineMaps, verify that each map has the corresponding
+/// `expectedSize`.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+ ArrayRef<int64_t> expectedSizes) {
+ if (indexingMaps.size() != expectedSizes.size())
+ return false;
+
+ for (auto [indexingMap, expectedSize] :
+ llvm::zip_equal(indexingMaps, expectedSizes)) {
+ auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
+ if (affineMap.getNumResults() != expectedSize)
+ return false;
+ }
+ return true;
+}
+
+/// Utility to update `dilations` and `strides` by copying the corresponding
+/// data from `tempDilations` and `tempStrides`.
+static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides,
+ ArrayRef<int64_t> tempDilations,
+ ArrayRef<int64_t> tempStrides) {
+ if (!(dilations && strides))
+ return true;
+ for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) {
+ dilations->push_back(dilation);
+ strides->push_back(stride);
+ }
+ return true;
+}
+
+static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+ return false;
+
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+ return false;
+
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[1],
+ tempStrides[1]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
+ return false;
+
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+ // -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+ // -> (d5, d6, d7, d8, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+ // -> (d0, d1, d2, d3, d8, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAt...
[truncated]
mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
Thanks for extracting this! Sharing my first set of comments. This is still quite dense, so I've not read everything yet 😅
mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
Thanks for splitting the PR, it is easier to review! I'll take a look at [Utils.cpp] changes once we are aligned on the code structure.
We are getting there 😅
I've started reviewing the utility functions, see my comments inline.
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
                                       unsigned fDim, unsigned oDim,
                                       int64_t &dilation, int64_t &stride) {
  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
Suggested change:
-  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIndex = 2;
Also, instead of using "magic" 0, 1 and 2, could you define some global constants instead?
I could do that, but for convolution ops outputMapIndex is indexingMaps.size() - 1, so it can be either 2 or 4 (in a few convolution ops' case).
How should I go about this?
Shouldn't there always be 3 maps? Input + filter + output?
Ideally, yes.
But ops like conv_2d_nchw_fchw_q, conv_2d_nhwc_hwcf_q, etc. take in 2 more input operands and have a structure similar to the following (as obtained via -linalg-generalize-named-ops):
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
module {
func.func @conv_2d_nhwc_hwcf_q(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: i32, %arg3: i32, %arg4: memref<?x?x?x?xf32>) {
linalg.generic {
indexing_maps = [#map, #map1, #map2, #map2, #map3],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
} ins(%arg0, %arg1, %arg2, %arg3 : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, i32, i32)
outs(%arg4 : memref<?x?x?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %in_1: i32, %in_2: i32, %out: f32):
%0 = arith.sitofp %in_1 : i32 to f32
%1 = arith.subf %in, %0 : f32
%2 = arith.sitofp %in_2 : i32 to f32
%3 = arith.subf %in_0, %2 : f32
%4 = arith.mulf %1, %3 : f32
%5 = arith.addf %out, %4 : f32
linalg.yield %5 : f32
}
return
}
}
And thus fetching indexing maps from such operations via op.getIndexingMaps() is going to yield 5 indexing maps.
If we have to use global constants as you suggested above, then perhaps creating a getNonEmptyIndexingMaps() or something of that sort would be a better way to go.
Let me know your thoughts.
/// indexingMaps[1].getResult(fDim) * <CST_1> +
/// indexingMaps[n-1].getResult(oDim) * <CST_2>
/// where, CST_1 and CST_2 can be any constant.
What are CST_1 and CST_2, I don't see those in the code?
So CST_1 and CST_2 would be dilations and strides constant.
I'm checking that using isDimTimesConstantOrDimOnly here :-
llvm-project/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Lines 351 to 352 in 7b47d9e
  if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
      isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
What I meant is that the comment refers to variables such as CST_1, but there are no such variables in the code. IIUC, CST_1 == c0 and CST_2 == c1? I suggest using consistent names.
Also, what dim expression are we matching here?
/// indexingMaps[0].getResult(iDim) ==
/// indexingMaps[1].getResult(fDim) * <CST_1> +
/// indexingMaps[n-1].getResult(oDim) * <CST_2>
Put differently, could you type an example where CST_1 and CST_2 != 1?
Done. Could you check now?
Just some minor comments in this round - slowly returning to this (apologies for the delay - travelling)
Thanks for all the updates so far 🙏🏻